{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import, division, print_function, unicode_literals\n",
    "\n",
    "import tensorflow as tf\n",
    "import jieba\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.ticker as ticker\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "import unicodedata\n",
    "import re\n",
    "import numpy as np\n",
    "import os\n",
    "import io\n",
    "import time\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "指定数据集路径（需先将 cmn-eng 数据集下载到本地）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): machine-specific absolute path; change it to wherever your copy of cmn.txt lives.\n",
    "path_to_file = \"D:/copy/pycharmWorkspace/ts/seq2seq_attention_trans/cmn-eng/cmn.txt\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "将 unicode 文件转换为 ascii"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def unicode_to_ascii(s):\n",
    "    \"\"\"Drop combining marks (Unicode category 'Mn') from `s` after NFD decomposition.\n",
    "\n",
    "    Despite the name, non-ASCII base characters are kept; only the marks are removed.\n",
    "    \"\"\"\n",
    "    decomposed = unicodedata.normalize('NFD', s)\n",
    "    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']\n",
    "    return ''.join(kept)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess_sentence(w):\n",
    "    \"\"\"Normalize one sentence and wrap it in <start>/<end> markers.\n",
    "\n",
    "    The sentence is lower-cased, stripped of combining marks, gets spaces around\n",
    "    the punctuation ? . ! , and the inverted question mark, is segmented with\n",
    "    jieba, and has all whitespace runs collapsed to single spaces.\n",
    "\n",
    "    Args:\n",
    "        w: raw sentence (str).\n",
    "\n",
    "    Returns:\n",
    "        The string '<start> ' + space-separated tokens + ' <end>'.\n",
    "    \"\"\"\n",
    "    w = unicode_to_ascii(w.lower().strip())\n",
    "\n",
    "    # Put a space on both sides of punctuation,\n",
    "    # e.g. \"he is a boy.\" => \"he is a boy .\"\n",
    "    # See: https://stackoverflow.com/questions/3645931/python-padding-punctuation-with-white-spaces-keeping-punctuation\n",
    "    w = re.sub(r\"([?.!,¿])\", r\" \\1 \", w)\n",
    "    w = re.sub(r'[\" \"]+', \" \", w)\n",
    "\n",
    "    # Segment with jieba and join the pieces with spaces; the argument-less\n",
    "    # split() then drops leading/trailing blanks and collapses repeated spaces.\n",
    "    segmented = \" \".join(jieba.cut(w))\n",
    "    w = ' '.join(segmented.split())\n",
    "\n",
    "    # Start/end markers tell the model where to begin and stop predicting.\n",
    "    return '<start> ' + w + ' <end>'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke-test the preprocessing on one English and one Spanish sentence.\n",
    "en_sentence = u\"May I borrow this book?\"\n",
    "sp_sentence = u\"¿Puedo tomar prestado este libro?\"\n",
    "print(preprocess_sentence(en_sentence))\n",
    "print(preprocess_sentence(sp_sentence).encode('utf-8'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. 去除重音符号\n",
    "2. 清理句子\n",
    "3. 返回这样格式的单词对：[ENGLISH, CHINESE]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_dataset(path, num_examples):\n",
    "    \"\"\"Read a tab-separated parallel corpus and return its preprocessed columns.\n",
    "\n",
    "    Args:\n",
    "        path: UTF-8 text file with one sentence pair per line, fields separated by tabs.\n",
    "        num_examples: number of leading lines to use; None uses every line.\n",
    "\n",
    "    Returns:\n",
    "        A zip iterator yielding one tuple of preprocessed sentences per column\n",
    "        (first column, then second column).\n",
    "    \"\"\"\n",
    "    # Context manager so the file handle is closed as soon as the text is read.\n",
    "    with io.open(path, encoding='UTF-8') as corpus_file:\n",
    "        lines = corpus_file.read().strip().split('\\n')\n",
    "\n",
    "    # Keep only the first two fields of each line. NOTE(review): if the corpus file\n",
    "    # carries extra trailing fields (e.g. an attribution column), they would\n",
    "    # otherwise break the two-way unpacking done in load_dataset.\n",
    "    word_pairs = [[preprocess_sentence(w) for w in l.split('\\t')[:2]]\n",
    "                  for l in lines[:num_examples]]\n",
    "\n",
    "    return zip(*word_pairs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def max_length(tensor):\n",
    "    \"\"\"Return the length of the longest sequence contained in `tensor`.\"\"\"\n",
    "    lengths = [len(seq) for seq in tensor]\n",
    "    return max(lengths)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize(lang):\n",
    "    \"\"\"Fit a Keras tokenizer on `lang` and return (post-padded id matrix, tokenizer).\"\"\"\n",
    "    # filters='' disables the tokenizer's character filtering for these texts.\n",
    "    lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')\n",
    "    lang_tokenizer.fit_on_texts(lang)\n",
    "\n",
    "    sequences = lang_tokenizer.texts_to_sequences(lang)\n",
    "    padded = tf.keras.preprocessing.sequence.pad_sequences(sequences, padding='post')\n",
    "\n",
    "    return padded, lang_tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_dataset(path, num_examples=None):\n",
    "    \"\"\"Load the corpus at `path` and tokenize both sides.\n",
    "\n",
    "    The first column returned by create_dataset is used as the target language\n",
    "    and the second as the input language.\n",
    "\n",
    "    Returns:\n",
    "        (input_tensor, target_tensor, inp_lang_tokenizer, targ_lang_tokenizer)\n",
    "    \"\"\"\n",
    "    # Cleaned sentence columns: target side first, input side second.\n",
    "    targ_lang, inp_lang = create_dataset(path, num_examples)\n",
    "    input_tensor, inp_lang_tokenizer = tokenize(inp_lang)\n",
    "    target_tensor, targ_lang_tokenizer = tokenize(targ_lang)\n",
    "    return input_tensor, target_tensor, inp_lang_tokenizer, targ_lang_tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Try experimenting with different dataset sizes.\n",
    "num_examples = 30000\n",
    "input_tensor, target_tensor, inp_lang, targ_lang = load_dataset(path_to_file, num_examples)\n",
    "\n",
    "# Longest sequence length in the target and input tensors.\n",
    "max_length_targ, max_length_inp = max_length(target_tensor), max_length(input_tensor)\n",
    "\n",
    "# Split into training and validation sets with an 80/20 ratio.\n",
    "# TODO(author): questions to ask - 1. the split as the teacher described it; 2. concave plot;\n",
    "# 3. 1e-08 vs. the current 1e-07; 4. the bottom cell; 5. course-project framework.\n",
    "input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(input_tensor,target_tensor,test_size=0.2)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shuffle buffer as large as the training set.\n",
    "BUFFER_SIZE=len(input_tensor_train)\n",
    "BATCH_SIZE = 64\n",
    "steps_per_epoch = len(input_tensor_train) // BATCH_SIZE\n",
    "embedding_dim = 256\n",
    "units = 1024\n",
    "# Vocabulary sizes are one more than the number of entries in each tokenizer's word_index.\n",
    "vocab_inp_size = len(inp_lang.word_index) + 1\n",
    "vocab_tar_size = len(targ_lang.word_index) + 1\n",
    "\n",
    "# Shuffled training pairs in fixed-size batches; the incomplete last batch is dropped.\n",
    "dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, target_tensor_train)).shuffle(BUFFER_SIZE)\n",
    "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Encoder(tf.keras.Model):\n",
    "    \"\"\"Embedding layer followed by one GRU that returns its output sequence and final state.\"\"\"\n",
    "    def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):\n",
    "        super(Encoder, self).__init__()\n",
    "        self.batch_sz = batch_sz\n",
    "        self.enc_units = enc_units\n",
    "        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n",
    "        self.gru = tf.keras.layers.GRU(self.enc_units,\n",
    "                                       return_sequences=True,\n",
    "                                       return_state=True,\n",
    "                                       recurrent_initializer='glorot_uniform')\n",
    "\n",
    "    def call(self, x, hidden):\n",
    "        \"\"\"Embed `x`, run the GRU starting from `hidden`, and return (output, state).\"\"\"\n",
    "        x = self.embedding(x)\n",
    "        output, state = self.gru(x, initial_state=hidden)\n",
    "        return output, state\n",
    "\n",
    "    def initialize_hidden_state(self):\n",
    "        \"\"\"Return an all-zero tensor of shape (batch_sz, enc_units).\"\"\"\n",
    "        return tf.zeros((self.batch_sz, self.enc_units))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encoder over the input-language vocabulary, sized for training batches of BATCH_SIZE.\n",
    "encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BahdanauAttention(tf.keras.layers.Layer):\n",
    "    \"\"\"Additive attention: score = V(tanh(W1(values) + W2(query))), softmax over axis 1.\"\"\"\n",
    "    def __init__(self, units):\n",
    "        super(BahdanauAttention, self).__init__()\n",
    "        self.W1 = tf.keras.layers.Dense(units)\n",
    "        self.W2 = tf.keras.layers.Dense(units)\n",
    "        self.V = tf.keras.layers.Dense(1)\n",
    "\n",
    "    def call(self, query, values):\n",
    "        \"\"\"Return (context_vector, attention_weights) for `query` attending over `values`.\"\"\"\n",
    "        # query (hidden state) shape == (batch_size, hidden_size)\n",
    "        # hidden_with_time_axis shape == (batch_size, 1, hidden_size)\n",
    "        # The extra axis lets the addition below broadcast when computing the score.\n",
    "        hidden_with_time_axis = tf.expand_dims(query, 1)\n",
    "\n",
    "        # score shape == (batch_size, max_length, 1)\n",
    "        # The last axis is 1 because the score is passed through self.V.\n",
    "        # Before self.V is applied the tensor shape is (batch_size, max_length, units).\n",
    "        score = self.V(tf.nn.tanh(\n",
    "            self.W1(values) + self.W2(hidden_with_time_axis)))\n",
    "\n",
    "        # attention_weights shape == (batch_size, max_length, 1)\n",
    "        attention_weights = tf.nn.softmax(score, axis=1)\n",
    "\n",
    "        # context_vector shape after the sum == (batch_size, hidden_size)\n",
    "        context_vector = attention_weights * values\n",
    "        context_vector = tf.reduce_sum(context_vector, axis=1)\n",
    "\n",
    "        return context_vector, attention_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): attention_layer is not referenced by any other cell in this notebook;\n",
    "# the Decoder builds its own BahdanauAttention. Looks like a leftover - confirm before removing.\n",
    "attention_layer = BahdanauAttention(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Decoder(tf.keras.Model):\n",
    "    \"\"\"Single-step GRU decoder with Bahdanau attention over the encoder outputs.\"\"\"\n",
    "\n",
    "    def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):\n",
    "        super(Decoder, self).__init__()\n",
    "        self.batch_sz = batch_sz\n",
    "        self.dec_units = dec_units\n",
    "        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n",
    "        self.gru = tf.keras.layers.GRU(self.dec_units,\n",
    "                                       return_sequences=True,\n",
    "                                       return_state=True,\n",
    "                                       recurrent_initializer='glorot_uniform')\n",
    "        # Projects each GRU output to vocabulary logits.\n",
    "        self.fc = tf.keras.layers.Dense(vocab_size)\n",
    "\n",
    "        # Used for attention over the encoder outputs.\n",
    "        self.attention = BahdanauAttention(self.dec_units)\n",
    "\n",
    "    def call(self, x, hidden, enc_output):\n",
    "        \"\"\"Run one decoding step.\n",
    "\n",
    "        Args:\n",
    "            x: token ids fed in for this step.\n",
    "            hidden: previous decoder state; used as the attention query and as the\n",
    "                GRU's initial state.\n",
    "            enc_output: encoder output sequence to attend over.\n",
    "\n",
    "        Returns:\n",
    "            (logits, state, attention_weights)\n",
    "        \"\"\"\n",
    "        # enc_output shape == (batch_size, max_length, hidden_size)\n",
    "        context_vector, attention_weights = self.attention(hidden, enc_output)\n",
    "\n",
    "        # x shape after the embedding layer == (batch_size, 1, embedding_dim)\n",
    "        x = self.embedding(x)\n",
    "\n",
    "        # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)\n",
    "        x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n",
    "\n",
    "        # Feed the merged vector to the GRU, continuing from the previous decoder\n",
    "        # state. Without initial_state the GRU would restart from zeros on every\n",
    "        # step and `hidden` would influence the output only through attention.\n",
    "        output, state = self.gru(x, initial_state=hidden)\n",
    "\n",
    "        # output shape == (batch_size * 1, hidden_size)\n",
    "        output = tf.reshape(output, (-1, output.shape[2]))\n",
    "\n",
    "        # logits shape == (batch_size, vocab)\n",
    "        x = self.fc(output)\n",
    "\n",
    "        return x, state, attention_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)\n",
    "\n",
    "optimizer = tf.keras.optimizers.Adam()\n",
    "# reduction='none' keeps one loss value per example so loss_function can mask padding.\n",
    "loss_object = tf.keras.losses.SparseCategoricalCrossentropy(\n",
    "    from_logits=True, reduction='none')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_function(real, pred):\n",
    "    \"\"\"Per-example cross-entropy with entries whose target id is 0 zeroed out, then averaged.\"\"\"\n",
    "    per_example = loss_object(real, pred)\n",
    "\n",
    "    # 1.0 where the target id is non-zero, 0.0 where it is 0.\n",
    "    keep = tf.math.not_equal(real, 0)\n",
    "    per_example *= tf.cast(keep, dtype=per_example.dtype)\n",
    "\n",
    "    return tf.reduce_mean(per_example)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint object tracking the optimizer and both models; files are written under checkpoint_dir.\n",
    "checkpoint_dir = './training_checkpoints'\n",
    "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n",
    "checkpoint = tf.train.Checkpoint(optimizer=optimizer,\n",
    "                                 encoder=encoder,\n",
    "                                 decoder=decoder)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def train_step(inp, targ, enc_hidden):\n",
    "    \"\"\"Run one teacher-forced training step on a batch and apply the gradients.\n",
    "\n",
    "    Returns the summed per-step loss divided by the target sequence length.\n",
    "    \"\"\"\n",
    "    loss = 0\n",
    "\n",
    "    with tf.GradientTape() as tape:\n",
    "        enc_output, enc_hidden = encoder(inp, enc_hidden)\n",
    "\n",
    "        # The decoder starts from the encoder's final state.\n",
    "        dec_hidden = enc_hidden\n",
    "\n",
    "        # First decoder input: one <start> id per example, shape (BATCH_SIZE, 1).\n",
    "        dec_input = tf.expand_dims([targ_lang.word_index['<start>']] * BATCH_SIZE, 1)\n",
    "        # Teacher forcing - the target word is fed as the next input.\n",
    "        for t in range(1, targ.shape[1]):\n",
    "            # Pass the encoder output (enc_output) to the decoder.\n",
    "            predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)\n",
    "\n",
    "            loss += loss_function(targ[:, t], predictions)\n",
    "            # Use teacher forcing.\n",
    "            dec_input = tf.expand_dims(targ[:, t], 1)\n",
    "\n",
    "    batch_loss = (loss / int(targ.shape[1]))\n",
    "\n",
    "    variables = encoder.trainable_variables + decoder.trainable_variables\n",
    "\n",
    "    # Gradients are taken of the summed loss, not of the length-normalized batch_loss.\n",
    "    gradients = tape.gradient(loss, variables)\n",
    "\n",
    "    optimizer.apply_gradients(zip(gradients, variables))\n",
    "\n",
    "    return batch_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of passes over the training set made by the training loop.\n",
    "EPOCHS = 10\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for epoch in range(EPOCHS):\n",
    "    start = time.time()\n",
    "\n",
    "    # The same zero initial encoder state is passed to every batch of the epoch.\n",
    "    enc_hidden = encoder.initialize_hidden_state()\n",
    "    total_loss = 0\n",
    "\n",
    "    for (batch, (inp, targ)) in enumerate(dataset.take(steps_per_epoch)):\n",
    "        batch_loss = train_step(inp, targ, enc_hidden)\n",
    "        total_loss += batch_loss\n",
    "\n",
    "        if batch % 100 == 0:\n",
    "            print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,\n",
    "                                                         batch,\n",
    "                                                         batch_loss.numpy()))\n",
    "    # Save a checkpoint of the model every 2 epochs.\n",
    "    if (epoch + 1) % 2 == 0:\n",
    "        checkpoint.save(file_prefix=checkpoint_prefix)\n",
    "\n",
    "    print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n",
    "                                        total_loss / steps_per_epoch))\n",
    "    print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(sentence):\n",
    "    \"\"\"Greedily decode a translation for one raw input sentence.\n",
    "\n",
    "    Returns:\n",
    "        (result, sentence, attention_plot): the predicted tokens joined by spaces\n",
    "        (with a trailing space), the preprocessed input sentence, and a\n",
    "        (max_length_targ, max_length_inp) array of attention weights per step.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if a token of the preprocessed sentence is not in inp_lang.word_index.\n",
    "    \"\"\"\n",
    "    attention_plot = np.zeros((max_length_targ, max_length_inp))\n",
    "\n",
    "    sentence = preprocess_sentence(sentence)\n",
    "\n",
    "    inputs = [inp_lang.word_index[i] for i in sentence.split(' ')]\n",
    "    inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs],\n",
    "                                                           maxlen=max_length_inp,\n",
    "                                                           padding='post')\n",
    "    inputs = tf.convert_to_tensor(inputs)\n",
    "\n",
    "    result = ''\n",
    "\n",
    "    # Zero initial encoder state for a batch of one.\n",
    "    hidden = [tf.zeros((1, units))]\n",
    "    enc_out, enc_hidden = encoder(inputs, hidden)\n",
    "\n",
    "    dec_hidden = enc_hidden\n",
    "    dec_input = tf.expand_dims([targ_lang.word_index['<start>']], 0)\n",
    "\n",
    "    for t in range(max_length_targ):\n",
    "        predictions, dec_hidden, attention_weights = decoder(dec_input,\n",
    "                                                             dec_hidden,\n",
    "                                                             enc_out)\n",
    "        # Store the attention weights so they can be plotted later.\n",
    "        attention_weights = tf.reshape(attention_weights, (-1,))\n",
    "        attention_plot[t] = attention_weights.numpy()\n",
    "\n",
    "        # Greedy choice: the id with the highest logit.\n",
    "        predicted_id = tf.argmax(predictions[0]).numpy()\n",
    "\n",
    "        result += targ_lang.index_word[predicted_id] + ' '\n",
    "\n",
    "        if targ_lang.index_word[predicted_id] == '<end>':\n",
    "            return result, sentence, attention_plot\n",
    "        # The predicted id is fed back into the model.\n",
    "        dec_input = tf.expand_dims([predicted_id], 0)\n",
    "\n",
    "    return result, sentence, attention_plot"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "注意力权重制图函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_attention(attention, sentence, predicted_sentence):\n",
    "    \"\"\"Show `attention` as a matrix plot labelled with input and predicted tokens.\n",
    "\n",
    "    `sentence` and `predicted_sentence` are lists of tokens used as the x and y\n",
    "    tick labels respectively.\n",
    "    \"\"\"\n",
    "    fig = plt.figure(figsize=(10, 10))\n",
    "    ax = fig.add_subplot(1, 1, 1)\n",
    "    ax.matshow(attention, cmap='viridis')\n",
    "\n",
    "    fontdict = {'fontsize': 14}\n",
    "\n",
    "    # A leading empty label is prepended to each list of tick labels.\n",
    "    ax.set_xticklabels([''] + sentence, fontdict=fontdict, rotation=90)\n",
    "    ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)\n",
    "\n",
    "    # One major tick per unit on both axes.\n",
    "    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))\n",
    "    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))\n",
    "\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def translate(sentence):\n",
    "    \"\"\"Translate `sentence`, print the input and the prediction, and plot the attention weights.\"\"\"\n",
    "    result, sentence, attention_plot = evaluate(sentence)\n",
    "\n",
    "    print('Input: %s' % (sentence))\n",
    "    print('Predicted translation: {}'.format(result))\n",
    "\n",
    "    # evaluate() appends a space after every predicted token, so result ends with a\n",
    "    # space; split(' ') would yield an extra empty token (an extra tick label and\n",
    "    # usually an unused row in the plot). The argument-less split() drops it.\n",
    "    predicted_tokens = result.split()\n",
    "    input_tokens = sentence.split(' ')\n",
    "\n",
    "    attention_plot = attention_plot[:len(predicted_tokens), :len(input_tokens)]\n",
    "    plot_attention(attention_plot, input_tokens, predicted_tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Translate a sample Chinese sentence with the models trained above.\n",
    "translate(u'它是一封信')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.9",
   "language": "python",
   "name": "python37"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
